Skip to content

[MLIR][NVVM] Support generating all the ldmatrix intrinsics from NVVM ops #148783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 12, 2025

Conversation

Pecco-314
Copy link
Contributor

Previously, the NVVM dialect's ldmatrix operation could only generate a limited subset of the available NVVM ldmatrix intrinsics. The intrinsics generating new ops introduced in BlackWell are not accessible through the NVVM ops. This commit extends the ldmatrix operation to support all available ldmatrix intrinsics.

@Pecco-314 Pecco-314 requested a review from grypp as a code owner July 15, 2025 06:30
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Pecco (Pecco-314)

Changes

Previously, the NVVM dialect's ldmatrix operation could only generate a limited subset of the available NVVM ldmatrix intrinsics. The intrinsics generating new ops introduced in BlackWell are not accessible through the NVVM ops. This commit extends the ldmatrix operation to support all available ldmatrix intrinsics.


Patch is 22.00 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148783.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+34-2)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+2-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+8-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+79-22)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+2-2)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+13-4)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (-11)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+37-7)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..cfb21e8331d05 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,6 +1990,35 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+  [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def LdStMatrixEltTypeB16 : I32EnumAttrCase<"B16", 0, "b16">;
+def LdStMatrixEltTypeB8 : I32EnumAttrCase<"B8", 1, "b8">;
+def LdStMatrixEltTypeB8X16_B6X16_P32 : I32EnumAttrCase<"B8X16_B6X16_P32", 2, "b8x16.b6x16_p32">;
+def LdStMatrixEltTypeB8X16_B4X16_P64 : I32EnumAttrCase<"B8X16_B4X16_P64", 3, "b8x16.b4x16_p64">;
+
+def LdStMatrixEltType : I32EnumAttr<"LdStMatrixEltType", "Element type for ldmatrix and stmatrix",
+  [LdStMatrixEltTypeB16, LdStMatrixEltTypeB8,
+   LdStMatrixEltTypeB8X16_B6X16_P32, LdStMatrixEltTypeB8X16_B4X16_P64]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixEltTypeAttr : EnumAttr<NVVM_Dialect, LdStMatrixEltType, "ld_st_matrix_elttype"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
   Arguments<(ins LLVM_PointerShared:$ptr, 
                  Variadic<I32>:$sources, 
@@ -2021,13 +2050,16 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
 
 def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
   Results<(outs AnyType:$res)>,
-  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+  Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr: $num,
+                 MMALayoutAttr: $layout,
+                 LdStMatrixShapeAttr: $shape,
+                 LdStMatrixEltTypeAttr: $elttype)> {
 
   let summary = "cooperative matrix load";
 
   string llvmBuilder = [{
       auto operands = moduleTranslation.lookupValues(opInst.getOperands());
-      auto intId = getLdMatrixIntrinsicId($layout, $num);
+      auto intId = getLdMatrixIntrinsicId($layout, $num, $shape, $elttype);
       $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
   }];
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 80b3d85488495..470dc2512a9ad 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -289,7 +289,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
         ldMatrixResultType, srcPtr,
         /*num=*/op.getNumTiles(),
         /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
-                                     : NVVM::MMALayout::row);
+                                     : NVVM::MMALayout::row,
+        NVVM::LdStMatrixShape::M8N8, NVVM::LdStMatrixEltType::B16);
 
     // The ldmatrix operation returns either a single i32 value or a struct of
     // i32 values. Here we unpack those values and cast them back to their
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6e29b129e8835..93c155b67fb5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -806,14 +806,18 @@ LogicalResult NVVM::LdMatrixOp::verify() {
     return emitOpError("expected num attribute to be 1, 2 or 4");
 
   Type i32 = IntegerType::get(getContext(), 32);
-  if (getNum() == 1 && getType() != i32)
+  uint32_t num = getNum();
+  if (getShape() == LdStMatrixShape::M16N16) {
+    num *= 2;
+  }
+  if (num == 1 && getType() != i32)
     return emitOpError("expected destination type is i32");
-  if (getNum() == 2 || getNum() == 4) {
+  if (num == 2 || num == 4) {
     Type dstType = LLVM::LLVMStructType::getLiteral(
-        getContext(), SmallVector<Type>(getNum(), i32));
+        getContext(), SmallVector<Type>(num, i32));
     if (getType() != dstType)
       return emitOpError("expected destination type is a structure of ")
-             << getNum() << " elements of type i32";
+             << num << " elements of type i32";
   }
   return success();
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..5d13933519c54 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -134,33 +134,90 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
   llvm_unreachable("unsupported vote kind");
 }
 
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
-                                                  int32_t num) {
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+                       NVVM::LdStMatrixShape shape,
+                       NVVM::LdStMatrixEltType elttype) {
   if (layout == NVVM::MMALayout::row) {
-    switch (num) {
-    case 1:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
-    case 2:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
-    case 4:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
-    default:
-      llvm_unreachable("unsupported number of matrix");
+    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+        elttype == NVVM::LdStMatrixEltType::B16) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+      case 4:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M8N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+      case 4:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+      }
     }
-
   } else {
-    switch (num) {
-    case 1:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
-    case 2:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
-    case 4:
-      return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
-    default:
-      llvm_unreachable("unsupported number of matrix");
+    if (shape == NVVM::LdStMatrixShape::M8N8 &&
+        elttype == NVVM::LdStMatrixEltType::B16) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+      case 4:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+      case 2:
+        return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+      }
+    } else if (shape == NVVM::LdStMatrixShape::M16N16 &&
+               elttype == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+      switch (num) {
+      case 1:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+      case 2:
+        return llvm::Intrinsic::
+            nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+      }
     }
   }
+  llvm_unreachable("unsupported matrix configuration");
 }
 
 /// Return the intrinsic ID associated with st.bulk for the given address type.
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index d0bc806e0aa8c..75a556f471373 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -159,7 +159,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
 // CHECK-LABEL: @ldmatrix_x4
 func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 4 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> !llvm.struct<(i32, i32, i32, i32)>
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast
@@ -179,7 +179,7 @@ func.func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) ->  vector<4x2xf16> {
 // CHECK-LABEL: @ldmatrix_x1
 func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
   %c0  = arith.constant 0 : index
-  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+  // CHECK: nvvm.ldmatrix {{%.+}} {elttype = #nvvm.ld_st_matrix_elttype<b16>, layout = #nvvm.mma_layout<row>, num = 1 : i32, shape = #nvvm.ld_st_matrix_shape<m8n8>} : {{.*}} -> i32
   %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
   // CHECK: llvm.bitcast
   // CHECK: llvm.insertvalue
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bd1106e304c60..f9def0877d71a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1116,7 +1116,7 @@ llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr) -> i32
   llvm.return
 }
 
@@ -1124,7 +1124,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
-  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+  %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   llvm.return
 }
 
@@ -1132,7 +1132,7 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32)>
   llvm.return
 }
 
@@ -1140,10 +1140,19 @@ llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
 
 llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
   // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
-  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   llvm.return
 }
 
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<3>) {
+  // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
+  %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m16n16>, elttype = #nvvm.ld_st_matrix_elttype<b8>} : (!llvm.ptr<3>) -> i32
+  llvm.return
+}
+
+
 // -----
 
 llvm.func @caller() {
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..6a4edd0d22a08 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -385,17 +385,6 @@ llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
   llvm.return
 }
 
-// CHECK-LABEL: llvm.func @ld_matrix
-llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<3>) -> i32
-  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
-  // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-  llvm.return
-}
-
 // CHECK-LABEL: llvm.func @redux_sync
 llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
   // CHECK: nvvm.redux.sync  add %{{.*}}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..89429a762db92 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
 // CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> i32
+  %l1 = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
-   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> i32
+  %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}})
+  %l1t = nvvm.ldmatrix %arg0 {num = 1: i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> i32
   // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>, shape =#nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}})
-  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>, elttype = #nvvm.ld_st_matrix_elttype<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l2 = nvvm.ldmatrix %arg0 {num = 2: i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b6_l4 = nvvm.ldmatrix %arg0{num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>,elttype =#nvvm.ld_st_matrix_elttype<b8x16.b6x16_p32>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>, shape =#nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64.p3(ptr addrspace(3) %{{.*}})
+  %m8n16_b4_l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n16>, elttype = #nvvm.ld_st_matrix_elttype<b8x16.b4x16_p64>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m16n16.x1.trans.b8.p3(ptr add...
[truncated]

@Pecco-314
Copy link
Contributor Author

@grypp @durga4github @schwarzschild-radius This PR has been pending review for some time. Could someone please help review this change?

%l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
llvm.return
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for cleaning this up and consolidating all these in the nvvm specific invalid-tests file!

@@ -559,17 +559,47 @@ llvm.func @llvm_nvvm_cp_async_bulk_wait_group() {
// CHECK-LABEL: @ld_matrix
llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional]:
Now that we are re-writing this anyway, it would be good to move this into its own separate file in the nvvm/ directory below? say, ldmatrix.mlir

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point. But I think it is better to conduct a unified refactoring, splitting this file into multiple files, rather than doing it here.

@durga4github
Copy link
Contributor

I have an optional request. With or without that addressed, the changes here LGTM.

@schwarzschild-radius , kindly help with a round of review.

@@ -2032,13 +2032,16 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,

def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
Arguments<(ins LLVM_PointerShared: $ptr, I32Attr: $num,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove the space between : and the variable

// expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 2 elements of type i32}}
%l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>, eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> i32
llvm.return
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a newline at the end of the file?

Copy link
Contributor

@schwarzschild-radius schwarzschild-radius left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change! Just a minor nit! LGTM!

@peterbell10 peterbell10 merged commit 24f5385 into llvm:main Aug 12, 2025
9 checks passed
Copy link

@Pecco-314 Congratulations on having your first Pull Request (PR) merged into the LLVM Project!

Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR.

Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues.

How to do this, and the rest of the post-merge process, is covered in detail here.

If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again.

If you don't get any reports, no action is required from you. Your changes are working as expected, well done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants